#!/usr/bin/env python
"""
检查GPU可用性
"""
import torch

print("=" * 70)
print(" " * 20 + "GPU 可用性检查")
print("=" * 70)
print()

print(f"PyTorch 版本: {torch.__version__}")
print(f"CUDA 可用: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA 版本: {torch.version.cuda}")
    print(f"GPU 数量: {torch.cuda.device_count()}")
    print()
    
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"  显存总量: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.2f} GB")
    
    print()
    print("✓ 可以使用 CUDA GPU 加速！")
    print("  建议配置: device: 'cuda'")
else:
    print()
    print("⚠️ 未检测到可用的 CUDA GPU")
    print("  当前只能使用 CPU 模式")
    print("  如果您有 NVIDIA GPU，请确保:")
    print("    1. 安装了 NVIDIA 驱动")
    print("    2. 安装了 CUDA Toolkit")
    print("    3. 安装了 GPU 版本的 PyTorch")

print()
print("=" * 70)
